// Package com.rapidminer.operator.RatingPrediction
//
// Source Code of com.rapidminer.operator.RatingPrediction.BiPolarSlopeOne

package com.rapidminer.operator.RatingPrediction;

import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;

import com.rapidminer.data.SkewSymmetricSparseMatrix;
import com.rapidminer.data.SymetricSparseMatrix_i;
import com.rapidminer.operator.Annotations;
import com.rapidminer.operator.IOObject;
import com.rapidminer.operator.Operator;
import com.rapidminer.operator.ports.OutputPort;
import com.rapidminer.operator.ports.ProcessingStep;
import com.rapidminer.tools.LogService;
import com.rapidminer.tools.LoggingHandler;

/**
Copyright (C) 2011 Zeno Gantner

*This file is originally part of MyMediaLite.

*Ported by Matej Mihelcic (Ruđer Bošković Institute) 08.08.2011
*/

public class BiPolarSlopeOne extends RatingPredictor {

     static final long serialVersionUID=3453434;
    
      private SkewSymmetricSparseMatrix  diff_matrix_like;
      private SymetricSparseMatrix_i freq_matrix_like;
      private SkewSymmetricSparseMatrix  diff_matrix_dislike;
      private SymetricSparseMatrix_i freq_matrix_dislike;

    private double global_average;
    private double[] user_average;

    ///
    public boolean CanPredict(int user_id, int item_id)
    {
      if (user_id > MaxUserID || item_id > MaxItemID)
        return false;

     
      for(int i=0;i<GetRatings().ByUser().get(user_id).size();i++){
        int index=GetRatings().ByUser().get(user_id).get(i);

        if (freq_matrix_like.getLocation(item_id, GetRatings().GetItems().get(index)) != 0)
          return true;
        if (freq_matrix_dislike.getLocation(item_id, GetRatings().GetItems().get(index)) != 0)
          return true;
      }
      return false;
    }

    ///
    public double Predict(int user_id, int item_id)
    {
      if (item_id > MaxItemID || user_id > MaxUserID){
        return global_average;
      }

      double prediction = 0.0;
      int frequencies = 0;

     
      for(int i=0;i<GetRatings().ByUser().get(user_id).size();i++){
      int index=GetRatings().ByUser().get(user_id).get(i);
       

          if (GetRatings().GetValues(index) > user_average[user_id])
          {
            int f = freq_matrix_like.getLocation(item_id, GetRatings().GetItems().get(index));
            if (f != 0)
            {
              prediction  += ( diff_matrix_like.getLocation(item_id, GetRatings().GetItems().get(index))+ GetRatings().GetValues(index) ) * f;
              frequencies += f;
            }
          }
          else
          {
            int f = freq_matrix_dislike.getLocation(item_id, GetRatings().GetItems().get(index));
            if (f != 0)
            {
              prediction  += ( diff_matrix_dislike.getLocation(item_id, GetRatings().GetItems().get(index))+ GetRatings().GetValues(index) ) * f;
              frequencies += f;
            }
          }
        }
     
      if (frequencies == 0){
        return global_average;
      }
      double result = (double) (prediction / frequencies);

      if (result > GetMaxRating()){
        return GetMaxRating();
      }
     
      if (result < GetMinRating()){
        return GetMinRating();
      }
      return result;
    }

    ///
    public void Train()
    {
      InitModel();

      // default value if no prediction can be made
      global_average = GetRatings().Average();

      // compute difference sums and frequencies
       
      Iterator<Integer> it=GetRatings().AllUsers().iterator();
     
     
      while(it.hasNext()){
      int user_id=it.next();
         
        double user_avg = 0;
       
        for(int j=0;j<GetRatings().ByUser().get(user_id).size();j++){
         
          int index=GetRatings().ByUser().get(user_id).get(j);
          user_avg+=GetRatings().GetValues(index);
         
        }
       
        user_avg /= GetRatings().ByUser().get(user_id).size();

        // store for later use
        user_average[user_id] = user_avg;

       
        for(int j=0;j<GetRatings().ByUser().get(user_id).size();j++){
          int index=GetRatings().ByUser().get(user_id).get(j);
         
          for(int k=0;k<GetRatings().ByUser().get(user_id).size();k++){
            int index2=GetRatings().ByUser().get(user_id).get(k);
           
           
            if (GetRatings().GetValues(index) > user_avg && GetRatings().GetValues(index2) > user_avg)
            {
              freq_matrix_like.setLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2), freq_matrix_like.getLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2))+1);
              diff_matrix_like.setLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2), diff_matrix_like.getLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2))+(float) (GetRatings().GetValues(index) - GetRatings().GetValues(index2)));
            }
            else if (GetRatings().GetValues(index) < user_avg && GetRatings().GetValues(index2) < user_avg)
            {
              freq_matrix_dislike.setLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2), freq_matrix_dislike.getLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2))+1);
              diff_matrix_dislike.setLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2), diff_matrix_dislike.getLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2))+(float) (GetRatings().GetValues(index) - GetRatings().GetValues(index2)));
            }
          }
        }
      }

      // compute average differences
     
      for (int i = 0; i <= MaxItemID; i++){
       
        Set<Integer> s=freq_matrix_like.Get(i).keySet();
        Iterator<Integer> it1=s.iterator();
       
        while(it1.hasNext()){
         
        int ind=it1.next();
          diff_matrix_like.setLocation(i, ind, diff_matrix_like.getLocation(i,ind)/freq_matrix_like.getLocation(i, ind));
        }
       
        s=freq_matrix_dislike.Get(i).keySet();
        it1=s.iterator();
       
       
        while(it1.hasNext()){
         
          int ind=it1.next();
            diff_matrix_dislike.setLocation(i, ind, diff_matrix_dislike.getLocation(i,ind)/freq_matrix_dislike.getLocation(i, ind));
          }
       
      }
    }

    ///
    protected void InitModel()
    {
      super.InitModel();

      // create data structure
      diff_matrix_like = new SkewSymmetricSparseMatrix(MaxItemID + 1);
      freq_matrix_like = new SymetricSparseMatrix_i(MaxItemID + 1);
      diff_matrix_dislike = new SkewSymmetricSparseMatrix(MaxItemID + 1);
      freq_matrix_dislike = new SymetricSparseMatrix_i(MaxItemID + 1);
      user_average = new double[MaxUserID + 1];
    }

   
    public void AddUsers(List<Integer> users){
      super.AddUsers(users);
     
      double[] user_average_new = new double[users.get(users.size()-1)+ 1];
     
      for(int i=0;i<user_average.length;i++)
        user_average_new[i]=user_average[i];
     
      user_average=user_average_new; 
    }
   
    public void AddItems(List<Integer> items){
      super.AddItems(items);
    }
   
   
    public int AddRatings(List<Integer> users, List<Integer> items, List<Double> ratings){
   
      if(users==null)
        return 1;
     
      super.AddRatings(users, items, ratings);
      global_average = GetRatings().Average();

      // compute difference sums and frequencies
       
      for(int k1=0;k1<users.size();k1++){
      int user_id=users.get(k1);
         
        double user_avg = 0;
       
        for(int j=0;j<GetRatings().ByUser().get(user_id).size();j++){
         
          int index=GetRatings().ByUser().get(user_id).get(j);
          user_avg+=GetRatings().GetValues(index);
         
        }
       
        user_avg /= GetRatings().ByUser().get(user_id).size();

        // store for later use
        user_average[user_id] = user_avg;

       
        for(int j=0;j<GetRatings().ByUser().get(user_id).size();j++){
          int index=GetRatings().ByUser().get(user_id).get(j);
         
           if(GetRatings().GetItems().get(index)==items.get(k1))
          for(int k=0;k<GetRatings().ByUser().get(user_id).size();k++){
            int index2=GetRatings().ByUser().get(user_id).get(k);
           
            if (GetRatings().GetValues(index) > user_avg && GetRatings().GetValues(index2) > user_avg)
            {
              freq_matrix_like.setLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2), freq_matrix_like.getLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2))+1);
              diff_matrix_like.setLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2), diff_matrix_like.getLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2))+(float) (GetRatings().GetValues(index) - GetRatings().GetValues(index2)));
            }
            else if (GetRatings().GetValues(index) < user_avg && GetRatings().GetValues(index2) < user_avg)
            {
              freq_matrix_dislike.setLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2), freq_matrix_dislike.getLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2))+1);
              diff_matrix_dislike.setLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2), diff_matrix_dislike.getLocation(GetRatings().GetItems().get(index), GetRatings().GetItems().get(index2))+(float) (GetRatings().GetValues(index) - GetRatings().GetValues(index2)));
            }
          }
        }
      }

      // compute average differences
     
      for (int i = 0; i <= MaxItemID; i++){
       
        Set<Integer> s=freq_matrix_like.Get(i).keySet();
        Iterator<Integer> it1=s.iterator();
       
        while(it1.hasNext()){
         
        int ind=it1.next();
          diff_matrix_like.setLocation(i, ind, diff_matrix_like.getLocation(i,ind)/freq_matrix_like.getLocation(i, ind));
        }
       
        s=freq_matrix_dislike.Get(i).keySet();
        it1=s.iterator();
       
       
        while(it1.hasNext()){
         
          int ind=it1.next();
            diff_matrix_dislike.setLocation(i, ind, diff_matrix_dislike.getLocation(i,ind)/freq_matrix_dislike.getLocation(i, ind));
          }
       
      }
     
      return 1;
     
    }
   
    public void RetrainItems(List<Integer> items){
      super.RetrainItems(items);
     
    }
   
    public void RetrainUsers(List<Integer> users){
      super.RetrainUsers(users);
 
   
    ///
    public void LoadModel(String file)
    {
      //not needed
    }

    ///
    public void SaveModel(String file)
    {
      //not needed
    }

    ///
    public String ToString()
    {
       return "BipolarSlopeOne";
    }
   
   
      private String source = null;
       
        /** The current working operator. */
        private transient LoggingHandler loggingHandler;
       
        private transient LinkedList<ProcessingStep> processingHistory = new LinkedList<ProcessingStep>();
       
        /** Sets the source of this IOObject. */
        public void setSource(String sourceName) {
            this.source = sourceName;
        }

        /** Returns the source of this IOObject (might return null if the source is unknown). */
        public String getSource() {
            return source;
        }
       
        @Override
        public void appendOperatorToHistory(Operator operator, OutputPort port) {
          if (processingHistory == null) {
            processingHistory = new LinkedList<ProcessingStep>();
          if (operator.getProcess() != null)
            processingHistory.add(new ProcessingStep(operator, port));
        }
          ProcessingStep newStep = new ProcessingStep(operator, port);
          if (operator.getProcess() != null && (processingHistory.isEmpty() || !processingHistory.getLast().equals(newStep))) {
            processingHistory.add(newStep);
          }
        }
       
        @Override
        public List<ProcessingStep> getProcessingHistory() {
          if (processingHistory == null)
            processingHistory = new LinkedList<ProcessingStep>();
          return processingHistory;
        }
       
        /** Gets the logging associated with the operator currently working on this
         *  IOObject or the global log service if no operator was set. */
        public LoggingHandler getLog() {
            if (this.loggingHandler != null) {
                return this.loggingHandler;
            } else {
                return LogService.getGlobal();
            }
        }
       
        /** Sets the current working operator, i.e. the operator which is currently
         *  working on this IOObject. This might be used for example for logging. */
        public void setLoggingHandler(LoggingHandler loggingHandler) {
            this.loggingHandler = loggingHandler;
        }
       
      /**
       * Returns not a copy but the very same object. This is ok for IOObjects
       * which cannot be altered after creation. However, IOObjects which might be
       * changed (e.g. {@link com.rapidminer.example.ExampleSet}s) should
       * overwrite this method and return a proper copy.
       */
      public IOObject copy() {
        return this;
      }
     
      protected void initWriting() {}

   
      public Annotations getAnnotations(){
        Annotations temp=new Annotations();
        return temp;
      }
   
   
  }
// TOP
//
// Related Classes of com.rapidminer.operator.RatingPrediction.BiPolarSlopeOne
//
// TOP
// Copyright © 2018 www.massapi.com. All rights reserved.
// All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.